import argparse
import torch
import numpy as np
import torch.optim as optim
from utils import KpiReader
from models import StackedVAGT
from logger import Logger


class Tester(object):
    """Evaluate a trained model on a test loader and log per-batch statistics.

    ``load_checkpoint`` restores model/optimizer state from
    ``<checkpoints>_epochs<N>.pth``; ``model_test`` then runs the model over
    ``testloader`` and hands the collected values to ``Logger.log_tester``.
    """

    def __init__(self, model, test, testloader, log_path='log_tester', log_file='loss', device=torch.device('cpu'),
                 learning_rate=0.0002, nsamples=None, sample_path=None, checkpoints=None):
        """Store the test configuration and build the optimizer and logger.

        Args:
            model: network to evaluate; it is moved to ``device`` here.
            test: the test dataset object (stored as ``self.test``; not read by this class).
            testloader: iterable yielding ``(timestamps, labels, data)`` batches.
            log_path: directory passed to ``Logger``.
            log_file: file name passed to ``Logger``.
            device: torch device the model and input data are placed on.
            learning_rate: Adam learning rate. Within this class the optimizer is
                only used by ``load_checkpoint`` to restore saved optimizer state.
            nsamples: stored as-is; not read by this class.
            sample_path: stored as-is; not read by this class.
            checkpoints: path prefix; ``'_epochs{N}.pth'`` is appended when loading.
        """
        self.model = model
        self.model.to(device)
        self.device = device
        self.test = test
        self.testloader = testloader
        self.log_path = log_path
        self.log_file = log_file
        self.learning_rate = learning_rate
        self.nsamples = nsamples
        self.sample_path = sample_path
        self.checkpoints = checkpoints
        self.start_epoch = 0
        # Needed so a checkpoint's 'optimizer' state dict can be loaded.
        self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
        self.epoch_losses = []
        self.logger = Logger(self.log_path, self.log_file)
        # Reused across batches; the same keys are overwritten for every batch.
        self.loss = {}

    def load_checkpoint(self, start_ep):
        """Restore model, optimizer, beta and loss history from a checkpoint.

        The file loaded is ``self.checkpoints + '_epochs{start_ep}.pth'``. Tensors
        are mapped onto ``self.device`` so that a checkpoint saved on a GPU can be
        loaded on a CPU-only machine (and vice versa).

        Loading is best-effort. If the file is missing, or loading fails for any
        other reason, a message is printed and ``self.start_epoch`` is reset to 0.
        The failure cause is printed rather than hidden.

        Args:
            start_ep: epoch number embedded in the checkpoint file name.
        """
        path = self.checkpoints + '_epochs{}.pth'.format(start_ep)
        try:
            print("Loading Checkpoint from ' {} '".format(path))
            checkpoint = torch.load(path, map_location=self.device)
            self.start_epoch = checkpoint['epoch']
            self.model.beta = checkpoint['beta']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.epoch_losses = checkpoint['losses']
            print("Resuming Training From Epoch {}".format(self.start_epoch))
        except FileNotFoundError:
            print("No Checkpoint Exists At '{}', Starting Fresh Training".format(path))
            self.start_epoch = 0
        except Exception as err:
            # A checkpoint exists but could not be used (device mismatch, missing
            # key, incompatible state dict, ...). Report the real cause instead of
            # pretending the file is absent.
            print("Failed To Load Checkpoint At '{}' ({}: {}), Starting Fresh Training".format(
                path, type(err).__name__, err))
            self.start_epoch = 0

    def _evaluate_timestep(self, t, timestamps, labels, data, x_mu, x_logsigma, z_layers):
        """Collect the logged statistics of the last batch sample at time index ``t``.

        NOTE(review): the indexing assumes ``timestamps``/``labels`` are 4-D and
        ``data``/``x_mu``/``x_logsigma`` are 5-D, with batch on axis 0 and time on
        axis 1, and each latent tensor in ``z_layers`` is indexed as
        ``[batch, time, :]``. Confirm against the model's output shapes.

        Returns:
            ``(timestamp_tensor, anomaly_str, llh_tensor, z_values)``, where
            ``anomaly_str`` is ``"Anomaly"`` if any selected label equals 1 and
            ``"Normaly"`` otherwise (spelling kept verbatim; downstream code may
            compare against it). ``z_values`` holds one Python list per latent layer.
        """
        timestamp = timestamps[-1, t, -1, -1]
        label = labels[-1, t, -1, -1]
        anomaly_str = "Anomaly" if bool((label == 1).any()) else "Normaly"
        llh = self.loglikelihood_last_timestamp(data[-1, t, -1, :, -1],
                                                x_mu[-1, t, -1, :, -1],
                                                x_logsigma[-1, t, -1, :, -1])
        z_values = [z[-1, t, :].cpu().tolist() for z in z_layers]
        return timestamp, anomaly_str, llh, z_values

    def model_test(self):
        """Run the model over ``self.testloader`` and log statistics per batch.

        For every batch only the LAST sample (index -1 on the batch axis) is
        evaluated, at two time indices: the final one and the "verified" one,
        ``int(T / 2)`` with ``T = timestamps.shape[1]``. For each index the
        timestamp, anomaly flag, Gaussian reconstruction log-likelihood and
        posterior latent vector(s) are written into ``self.loss``. The dict is then
        passed to ``self.logger.log_tester`` together with the number of latent
        layers.
        """
        self.model.eval()
        for timestamps, labels, data in self.testloader:
            data = data.to(self.device)
            z_posterior_forward, _, _, _, _, x_mu, x_logsigma = self.forward_test(data)
            # One latent tensor per forward pass; kept in a list so the per-layer
            # keys below (zf_*_{ly}, llh_z_*_{ly}) stay generic in the layer count.
            z_posterior_forward_list = [z_posterior_forward]
            L = len(z_posterior_forward_list)

            T = int(timestamps.shape[1])
            verified_t = int(T / 2)

            last_timestamp, isanomaly, llh_last_timestamp, z_last = self._evaluate_timestep(
                -1, timestamps, labels, data, x_mu, x_logsigma, z_posterior_forward_list)
            verified_timestamp, verified_isanomaly, llh_verified_timestamp, z_verified = self._evaluate_timestep(
                verified_t, timestamps, labels, data, x_mu, x_logsigma, z_posterior_forward_list)

            # Key insertion order is preserved exactly, in case the logger iterates the dict.
            self.loss['Last_timestamp'] = last_timestamp.item()
            self.loss['Llh_Lt'] = llh_last_timestamp.item()
            self.loss['IA'] = isanomaly

            self.loss['llh_xz_lt'] = llh_last_timestamp.item()
            # No latent likelihood is computed here; 0.0 placeholders are written
            # (presumably so the logger always finds these keys).
            self.loss['llh_z_lt'] = 0.0
            for ly in range(L - 1):
                self.loss['llh_z_lt_{}'.format(ly)] = 0.0
            for ly in range(L):
                self.loss['zf_lt_{}'.format(ly)] = z_last[ly]

            self.loss['Verified_timestamp'] = verified_timestamp.item()
            self.loss['Llh_verified'] = llh_verified_timestamp.item()
            self.loss['IA_verified'] = verified_isanomaly

            self.loss['llh_xz_verified'] = llh_verified_timestamp.item()
            self.loss['llh_z_verified'] = 0.0
            for ly in range(L - 1):
                self.loss['llh_z_verified_{}'.format(ly)] = 0.0
            for ly in range(L):
                self.loss['zf_verified_{}'.format(ly)] = z_verified[ly]
            self.logger.log_tester(self.loss, L)

        print("Testing is complete!")

    def forward_test(self, data):
        """Run the model on ``data`` without gradient tracking.

        Returns the model's seven outputs, in the order the model produces them:
        ``(z_posterior, z_mean_posterior, z_logvar_posterior, z_mean_prior,
        z_logvar_prior, x_mu, x_logsigma)``. The explicit unpacking fails fast if
        the model returns a different number of outputs.
        """
        with torch.no_grad():
            z_posterior_forward_list, \
            z_mean_posterior_forward_list, \
            z_logvar_posterior_forward_list, \
            z_mean_prior_forward_list, \
            z_logvar_prior_forward_list, \
            x_mu_list, \
            x_logsigma_list = self.model(data)
            return z_posterior_forward_list, z_mean_posterior_forward_list, z_logvar_posterior_forward_list, \
                   z_mean_prior_forward_list, z_logvar_prior_forward_list, x_mu_list, x_logsigma_list

    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logsigma):
        """Return the Gaussian log-likelihood of ``x``, summed over all elements.

        Each element is scored under a normal distribution with mean
        ``recon_x_mu`` and standard deviation ``exp(recon_x_logsigma)``:
        ``-0.5 * ((x - mu) / sigma)**2 - log(sigma) - 0.5 * log(2*pi)``.
        Inputs are cast to float before computing. Returns a 0-d tensor.
        """
        llh = -0.5 * torch.sum(torch.pow(((x.float() - recon_x_mu.float()) / torch.exp(recon_x_logsigma.float())),
                                         2) + 2 * recon_x_logsigma.float() + np.log(np.pi * 2))
        return llh


def main():
    """Command-line entry point: load a trained StackedVAGT checkpoint, test it, and plot scores.

    Steps:
      1. Parse the CLI options.
      2. Build the test ``DataLoader`` and the model.
      3. Restore the checkpoint for ``--epochs`` via ``Tester.load_checkpoint``.
      4. Run ``Tester.model_test``.
      5. Ask the tester's logger to draw the anomaly-score and latent plots.

    Raises:
        ValueError: if ``--dataset_path`` does not exist.
    """
    import os
    # Allow a duplicate OpenMP runtime to load instead of aborting (a common
    # workaround when several libraries each bundle libiomp).
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    parser = argparse.ArgumentParser()
    # GPU
    parser.add_argument('--gpu_id', type=int, default=0)
    # Dataset options
    parser.add_argument('--dataset_path', type=str, default='../data_preprocess/dataPreprocessed/test/website-1-5')
    # NOTE: Tester.model_test only evaluates the last sample of each batch.
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--x_dim', type=int, default=307)
    parser.add_argument('--win_len', type=int, default=12)
    # Model options for VAGT (must match the trained checkpoint)
    parser.add_argument('--z_dim', type=int, default=15)
    parser.add_argument('--h_dim', type=int, default=20)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--layer_xz', type=int, default=1)
    parser.add_argument('--layer_h', type=int, default=3)
    parser.add_argument('--q_len', type=int, default=1, help='for conv1D padding in Transformer')
    parser.add_argument('--embd_h', type=int, default=128)
    parser.add_argument('--embd_s', type=int, default=256)
    parser.add_argument('--vocab_len', type=int, default=256)
    # Training options for VAGT (needed to construct the model / locate the checkpoint)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--learning_rate', type=float, default=0.0002)
    parser.add_argument('--beta', type=float, default=0.0)
    parser.add_argument('--max_beta', type=float, default=1.0)
    parser.add_argument('--anneal_rate', type=float, default=0.05)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--checkpoints_interval', type=int, default=5)
    parser.add_argument('--checkpoints_path', type=str, default='model/website-1-5')
    parser.add_argument('--checkpoints_file', type=str, default='')
    parser.add_argument('--log_path', type=str, default='log_tester/website-1-5')
    parser.add_argument('--log_file', type=str, default='')

    args = parser.parse_args()

    # Set up GPU
    if torch.cuda.is_available() and args.gpu_id >= 0:
        device = torch.device('cuda:%d' % args.gpu_id)
    else:
        device = torch.device('cpu')

    # For config checking
    if not os.path.exists(args.dataset_path):
        raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))

    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)

    if not os.path.exists(args.checkpoints_path):
        os.makedirs(args.checkpoints_path)

    # TODO Saving path names, for updating later...
    # Default names encode the hyper-parameters so they match the trainer's files.
    if args.checkpoints_file == '':
        args.checkpoints_file = 'x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_' \
                                'win_len-{}_q_len-{}_vocab_len-{}'.format(args.x_dim, args.z_dim, args.h_dim,
                                                                          args.layer_xz, args.layer_h, args.embd_h,
                                                                          args.n_head, args.win_len, args.q_len,
                                                                          args.vocab_len)
    if args.log_file == '':
        args.log_file = 'x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_win_len-{}_' \
                        'q_len-{}_vocab_len-{}_epochs-{}_loss'.format(args.x_dim, args.z_dim, args.h_dim, args.layer_xz,
                                                       args.layer_h, args.embd_h, args.n_head, args.win_len,
                                                       args.q_len, args.vocab_len, args.epochs)

    # For test dataset. The loader is NOT shuffled: test results are logged in
    # iteration order, so a fixed order keeps runs reproducible and logs ordered.
    kpi_value_test = KpiReader(args.dataset_path)
    test_loader = torch.utils.data.DataLoader(kpi_value_test, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers)

    # For models init
    stackedvagt = StackedVAGT(layer_xz=args.layer_xz, layer_h=args.layer_h, n_head=args.n_head, x_dim=args.x_dim,
                              z_dim=args.z_dim, h_dim=args.h_dim, embd_h=args.embd_h, embd_s=args.embd_s,
                              beta=args.beta, q_len=args.q_len, vocab_len=args.vocab_len, win_len=args.win_len,
                              dropout=args.dropout, anneal_rate=args.anneal_rate, max_beta=args.max_beta,
                              device=device).to(device)
    for name, parameters in stackedvagt.named_parameters():
        print(name, ':', parameters, parameters.size())
    # Start testing
    tester = Tester(stackedvagt, kpi_value_test, test_loader, log_path=args.log_path,
                    log_file=args.log_file, learning_rate=args.learning_rate, device=device,
                    checkpoints=os.path.join(args.checkpoints_path, args.checkpoints_file),
                    nsamples=None, sample_path=None)
    tester.load_checkpoint(args.epochs)
    tester.model_test()
    tester.logger.anomaly_score_plot_llh_x(y_range=[-50, 10])
    tester.logger.anomaly_score_plot_llh_xz(y_range=[-50, 30])
    tester.logger.anomaly_score_plot_llh_z(y_range=[-50, 10])
    tester.logger.anomaly_score_plot_llh_x_verified(y_range=[-50, 10])
    tester.logger.anomaly_score_plot_llh_xz_verified(y_range=[-50, 30])
    tester.logger.anomaly_score_plot_llh_z_verified(y_range=[-50, 10])
    tester.logger._plot_z(1)


if __name__ == "__main__":
    # Script entry: mute every Python warning for a cleaner console, then test.
    import warnings
    warnings.simplefilter('ignore')
    main()